Skip to content

Fix explicit-mesh sharding assert in deepseek batch-split scan#4208

Draft
ecnal-cienet wants to merge 1 commit into
mainfrom
fix/deepseek-batchsplit-context-reshard
Draft

Fix explicit-mesh sharding assert in deepseek batch-split scan#4208
ecnal-cienet wants to merge 1 commit into
mainfrom
fix/deepseek-batchsplit-context-reshard

Conversation

@ecnal-cienet

@ecnal-cienet ecnal-cienet commented Jun 19, 2026

Copy link
Copy Markdown
Collaborator

Problem

Training deepseek3-671b-batchsplit (which sets shard_mode=explicit and use_batch_split_schedule=true) crashes at the first train_step compile:

AssertionError: `with_sharding_constraint` acts as an assert when all axes of
mesh are of type Explicit. The array sharding: P(('data', 'fsdp', 'expert'), None)
did not match the sharding provided: P(('data', 'fsdp', 'expert', 'context'), None).

Root cause

Under JAX's explicit mesh axes, with_logical_constraint (used in loss_fn) is a hard assert that the array's sharding matches exactly — it no longer silently reshards as it did under auto axes.

scan_batch_split_layers reshards activations internally to P(('data','fsdp','expert'), None, None), deliberately dropping the context axis (size 1 here) for the manual split/merge collectives, and returns the hidden states still sharded that way. loss_fn then constrains xent/z_loss via activation_embed_and_logits_batch, which maps to ('data','fsdp','expert','context'). The two specs differ by the context axis, so the assert fires.

Fix

Capture the incoming sharding on entry and jax.reshard the output back to it before returning, so the batch-split path is transparent to downstream constraints. This mirrors what the per-layer DeepSeekDecoderLayer.__call__ path already does (input_sharding = jax.typeof(inputs).shardingjax.reshard(outputs, input_sharding)); scan_batch_split_layers was simply missing the restore step. Since context has size 1 in this config, this is the same physical layout — only the spec is made to line up for the explicit mesh.

Tests

Before Fix (Linen, passed):
After Fix (Linen, passed):
Before Fix (NNX, failed): https://cloudlogging.app.goo.gl/XF3eVFANEPx43XBG7
After Fix (NNX, passed):

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

@codecov

codecov Bot commented Jun 19, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 0% with 2 lines in your changes missing coverage. Please review.

Files with missing lines Patch % Lines
src/maxtext/trainers/pre_train/train.py 0.00% 2 Missing ⚠️

📢 Thoughts on this report? Let us know!

loss_fn's NNX branch constrained xent/z_loss with raw
nn.with_logical_constraint, which keeps the size-1 context axis, while
the model shards activations via create_sharding (remove_size_one_mesh_axis
drops context). Under an explicit mesh the mismatch is a hard assert.

Use sharding.maybe_shard_with_logical, as the Linen branch already does:
it builds the sharding via create_sharding (dropping size-1 context, so it
matches the array) and reshards instead of asserting under explicit mode.
@ecnal-cienet ecnal-cienet force-pushed the fix/deepseek-batchsplit-context-reshard branch from 1beca40 to 21b80d3 Compare June 19, 2026 21:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant